Skip to content

feat: Add GLA (Gated Linear Attention) Forward Operator (L2)#2

Open
superAngGao wants to merge 8 commits intomainfrom
feat/gla-fwd
Open

feat: Add GLA (Gated Linear Attention) Forward Operator (L2)#2
superAngGao wants to merge 8 commits intomainfrom
feat/gla-fwd

Conversation

@superAngGao
Copy link
Owner

Summary

Implements Gated Linear Attention (GLA) forward pass as a new L2 operator (Kernel + Op).

Closes tile-ai#213

Algorithm

Chunked GLA forward in 4 stages:

  1. Gate cumsum (PyTorch): within-chunk prefix sum of log-space gates
  2. Hidden state recurrence (PyTorch): inter-chunk h [B,NT,H,K,V] with gated decay
  3. Intra-chunk attention (TileLang): causal A [B,T,H,BT] with gated QK
  4. Output (TileLang): o = scale*(q*exp(g_cs))@h + A@v

Files Changed

File Description
tileops/kernels/gla/gla_fwd.py GLAFwdKernel — TileLang stages 3 & 4, sm90a
tileops/kernels/gla/__init__.py Kernel package export
tileops/ops/gla.py GLAFwdOp — Op wrapper
tileops/ops/__init__.py Register GLAFwdOp
tests/ops/test_gla.py 7 test cases (fp16 + bf16, ±initial_state)

Test Results

7/7 passed — fp16, bf16, dim_k=64/128, dim_v=64/128, chunk_size=32/64, with/without initial_state

Reference

https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py

Checklist

  • Kernel implemented (TileLang, sm90a)
  • Op wrapper with stable API
  • 7 correctness tests passing
  • __init__.py exports synchronized
  • Benchmark (out of scope for L2 — to be added separately)

@superAngGao superAngGao force-pushed the feat/gla-fwd branch 4 times, most recently from fc3c7ab to 5f1e0c7 Compare February 27, 2026 06:53
superAngGao and others added 8 commits February 27, 2026 16:54
Implements chunked GLA forward pass with:
- Stage 1+2 (PyTorch): within-chunk gate cumsum + inter-chunk hidden state recurrence
- Stage 3 (TileLang): intra-chunk causal attention matrix A [B, T, H, BT]
- Stage 4 (TileLang): output combining inter-chunk and intra-chunk contributions

Files added:
- tileops/kernels/gla/gla_fwd.py  -- GLAFwdKernel (sm90a)
- tileops/kernels/gla/__init__.py
- tileops/ops/gla.py              -- GLAFwdOp
- tests/ops/test_gla.py           -- 7 test cases (fp16 + bf16, with/without initial_state)

Closes tile-ai#213
Reference: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add seq_len % chunk_size == 0 assertion in GLAFwdOp to prevent OOB
  writes in TileLang kernels on non-divisible sequence lengths
- Cast k/v to float32 per-chunk in GLAFwdKernel.forward to reduce peak
  memory usage
- Fix k_adj formula in ref_gla_fwd to use log-space subtraction
  (matching GLAFwdKernel) instead of division with clamp
- Add test_gla_fwd_non_divisible_seq_len to verify the assertion fires
- Add skill.md files for create-new-kernel, create-new-op,
  create-new-op-test, creating-pull-request, migrating-new-op

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…on skill, add YAML frontmatter and auto-invoke to all skills

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…entions

- Single @T.prim_func with 4 @T.macro stages in one T.Serial(num_chunks) loop
- Stages run in order 1→3→4→2 so stage4 reads pre-decay h_s before stage2 updates it
- Hoist all shared buffers into _main and pass as parameters to eliminate duplicate allocations (stays within 232448 byte optin limit)
- Move shape lists inside _gla_fwd_func so outer closure only captures serializable scalars (fixes autotuner assertion)
- Add self.kernel assignment in __init__ to support autotune
- Fix custom_op namespace to top:: and add autotune_configs
- forward() only allocates buffers and calls wrapper; no PyTorch compute

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: Add GLA (Gated Linear Attention) Forward Operator (L2)

1 participant